{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b2b15f2e",
   "metadata": {},
   "source": [
    "# Run a trained policy\n",
    "\n",
    "This notebook will provide examples on how to run a trained policy and visualize the rollout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "000a4ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import urllib.request\n",
    "from copy import deepcopy\n",
    "\n",
    "import h5py\n",
    "import imageio\n",
    "import numpy as np\n",
    "import torch\n",
    "from IPython.display import Video\n",
    "\n",
    "import robomimic\n",
    "import robomimic.utils.file_utils as FileUtils\n",
    "import robomimic.utils.torch_utils as TorchUtils\n",
    "import robomimic.utils.tensor_utils as TensorUtils\n",
    "import robomimic.utils.obs_utils as ObsUtils\n",
    "from robomimic.envs.env_base import EnvBase\n",
    "from robomimic.algo import RolloutPolicy\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47427159",
   "metadata": {},
   "source": [
    "### Download policy checkpoint\n",
    "First, let's try downloading a pretrained model from our model zoo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dfdfe5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get pretrained checkpoint from the model zoo: Lift (Proficient Human)\n",
    "ckpt_url = \"http://downloads.cs.stanford.edu/downloads/rt_benchmark/model_zoo/lift/bc_rnn/lift_ph_low_dim_epoch_1000_succ_100.pth\"\n",
    "ckpt_path = \"lift_ph_low_dim_epoch_1000_succ_100.pth\"\n",
    "\n",
    "# skip the download when the checkpoint is already on disk so re-running the notebook is cheap\n",
    "if not os.path.exists(ckpt_path):\n",
    "    # download to a temporary name first so an interrupted transfer never leaves a\n",
    "    # truncated file at ckpt_path that a later run would mistake for a valid checkpoint\n",
    "    partial_path = ckpt_path + \".part\"\n",
    "    urllib.request.urlretrieve(ckpt_url, filename=partial_path)\n",
    "    os.replace(partial_path, ckpt_path)\n",
    "\n",
    "assert os.path.exists(ckpt_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c2c25c6",
   "metadata": {},
   "source": [
    "### Loading trained policy\n",
    "We have a convenient function called `policy_from_checkpoint` that takes care of building the correct model from the checkpoint and loading the trained weights. Of course, you could also load the checkpoint manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf84aed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose the torch device, asking for CUDA (try_to_use_cuda=True)\n",
    "device = TorchUtils.get_torch_device(try_to_use_cuda=True)\n",
    "\n",
    "# restore the policy; ckpt_dict is reused below to recreate the environment\n",
    "policy, ckpt_dict = FileUtils.policy_from_checkpoint(\n",
    "    ckpt_path=ckpt_path,\n",
    "    device=device,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2872a3f0",
   "metadata": {},
   "source": [
    "### Creating rollout environment\n",
    "The policy checkpoint also contains sufficient information to recreate the environment that it was trained with. Again, you may also create the environment manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d00c2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# recreate the environment from the checkpoint; the second return value is not needed here\n",
    "env, _ = FileUtils.env_from_checkpoint(\n",
    "    ckpt_dict=ckpt_dict,\n",
    "    verbose=True,\n",
    "    render=False,  # no on-screen rendering inside the notebook\n",
    "    render_offscreen=True,  # render to RGB images for the video\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7ac0e9f",
   "metadata": {},
   "source": [
    "### Define the rollout loop\n",
    "Now let's define the main rollout loop. The loop runs the policy to a target `horizon` and optionally writes the rollout to a video."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dd1375e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None):\n",
    "    \"\"\"\n",
    "    Helper function to carry out a single rollout. Supports on-screen rendering or off-screen\n",
    "    rendering to a video (not both at once), and returns summary statistics for the rollout.\n",
    "\n",
    "    Args:\n",
    "        policy (instance of RolloutPolicy): policy loaded from a checkpoint\n",
    "        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata\n",
    "        horizon (int): maximum horizon for the rollout\n",
    "        render (bool): whether to render rollout on-screen\n",
    "        video_writer (imageio writer): if provided, use to write rollout to video\n",
    "        video_skip (int): how often to write video frames\n",
    "        camera_names (list): determines which camera(s) are used for rendering. Pass more than\n",
    "            one to output a video with multiple camera views concatenated horizontally.\n",
    "            Must be non-empty whenever render is True or video_writer is provided.\n",
    "\n",
    "    Returns:\n",
    "        stats (dict): statistics for the rollout - Return (sum of step rewards), Horizon\n",
    "            (number of steps attempted), and Success_Rate (task success cast to float)\n",
    "    \"\"\"\n",
    "    assert isinstance(env, EnvBase)\n",
    "    assert isinstance(policy, RolloutPolicy)\n",
    "    assert not (render and (video_writer is not None))\n",
    "    if render or (video_writer is not None):\n",
    "        # fail early with a clear message instead of a TypeError on None inside the loop\n",
    "        assert camera_names, \"camera_names must be provided when rendering or writing a video\"\n",
    "\n",
    "    policy.start_episode()\n",
    "    obs = env.reset()\n",
    "    state_dict = env.get_state()\n",
    "\n",
    "    # hack that is necessary for robosuite tasks for deterministic action playback\n",
    "    obs = env.reset_to(state_dict)\n",
    "\n",
    "    video_count = 0  # video frame counter\n",
    "    total_reward = 0.\n",
    "    # defaults keep the stats below well-defined when horizon is 0 or when the very\n",
    "    # first env step raises a rollout exception (previously a NameError)\n",
    "    success = False\n",
    "    step_i = -1\n",
    "    try:\n",
    "        for step_i in range(horizon):\n",
    "\n",
    "            # get action from policy\n",
    "            act = policy(ob=obs)\n",
    "\n",
    "            # play action\n",
    "            next_obs, r, done, _ = env.step(act)\n",
    "\n",
    "            # accumulate reward and query task success\n",
    "            total_reward += r\n",
    "            success = env.is_success()[\"task\"]\n",
    "\n",
    "            # visualization\n",
    "            if render:\n",
    "                env.render(mode=\"human\", camera_name=camera_names[0])\n",
    "            if video_writer is not None:\n",
    "                if video_count % video_skip == 0:\n",
    "                    frames = [\n",
    "                        env.render(mode=\"rgb_array\", height=512, width=512, camera_name=cam_name)\n",
    "                        for cam_name in camera_names\n",
    "                    ]\n",
    "                    # concatenate the camera views horizontally into one frame\n",
    "                    video_writer.append_data(np.concatenate(frames, axis=1))\n",
    "                video_count += 1\n",
    "\n",
    "            # break if done or if success\n",
    "            if done or success:\n",
    "                break\n",
    "\n",
    "            # update for next iter\n",
    "            obs = deepcopy(next_obs)\n",
    "\n",
    "    except env.rollout_exceptions as e:\n",
    "        print(\"WARNING: got rollout exception {}\".format(e))\n",
    "\n",
    "    stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))\n",
    "\n",
    "    return stats\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b43d371",
   "metadata": {},
   "source": [
    "### Run the policy\n",
    "Now let's roll out the policy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be6e1878",
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed the numpy and torch random number generators\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "# maximum number of environment steps for the rollout\n",
    "rollout_horizon = 400\n",
    "\n",
    "# video file for the rollout frames, written at 20 fps\n",
    "video_path = \"rollout.mp4\"\n",
    "video_writer = imageio.get_writer(video_path, fps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fa67efe",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    stats = rollout(\n",
    "        policy=policy,\n",
    "        env=env,\n",
    "        horizon=rollout_horizon,\n",
    "        render=False,\n",
    "        video_writer=video_writer,\n",
    "        video_skip=5,\n",
    "        camera_names=[\"agentview\"],\n",
    "    )\n",
    "finally:\n",
    "    # always close the writer, even if the rollout raises, so it is not leaked\n",
    "    video_writer.close()\n",
    "print(stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe79bc19",
   "metadata": {},
   "source": [
    "### Visualize the rollout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97472b37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display the saved rollout video (Video is imported in the imports cell at the top)\n",
    "Video(video_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
